# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#

# APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security1

import hmac

import proto
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import modes
from utils import long_to_bytes
from utils import str_to_bytes

from .security import Security


def a_xor_b(a: bytes, b: bytes) -> bytes:
    """Return the bytewise XOR of ``a`` and ``b``, truncated to ``len(b)``.

    Output byte ``i`` is ``a[i] ^ b[i]`` for every ``i`` in ``range(len(b))``,
    so the result always has exactly ``len(b)`` bytes, zero bytes included.
    Extra trailing bytes of ``a`` are ignored.

    Raises:
        IndexError: if ``a`` is shorter than ``b``.
    """
    # XOR of two byte values is already in 0..255, so build the result
    # directly. Going through a generic int-to-bytes helper would make the
    # survival of zero bytes depend on how that helper encodes 0, and a
    # dropped byte would silently shorten the derived AES key.
    return bytes(a[i] ^ b[i] for i in range(len(b)))


# Enum for state of protocomm_security1 FSM
class security_state:
    """Integer states of the protocomm_security1 client FSM.

    They are listed in handshake order: send request 1, handle response 1
    and send request 2, handle response 2, then done.
    """

    REQUEST1, RESPONSE1_REQUEST2, RESPONSE2, FINISHED = range(4)


class Security1(Security):
    """Client side of the protocomm_security1 handshake and session cipher.

    :meth:`security1_session` is passed to the ``Security`` base-class
    constructor and drives a three-step FSM:

    1. Generate an X25519 key pair and send the client public key.
    2. Parse the device public key and random value. Derive the shared key,
       XORed with SHA256(PoP) when a PoP is set. Build an AES-CTR cipher using
       the device random as the counter block. Send the encrypted device
       public key as the client proof.
    3. Decrypt the device proof and check that it equals the client public key.

    After the handshake, :meth:`encrypt_data` / :meth:`decrypt_data` use the
    same cipher object.
    """

    def __init__(self, pop, verbose):
        """Create a security1 session context.

        Args:
            pop: Proof-of-possession. It is converted with ``str_to_bytes``;
                an empty value means no PoP is mixed into the shared key.
            verbose: When truthy, handshake values (including key material)
                are printed.
        """
        # Initialize state of the security1 FSM
        self.session_state = security_state.REQUEST1
        self.pop = str_to_bytes(pop)
        self.verbose = verbose
        Security.__init__(self, self.security1_session)

    def security1_session(self, response_data):
        """Advance the security1 FSM by one step.

        The state is advanced before the step's work runs, so an exception
        from parsing leaves the FSM in the following state.

        Args:
            response_data: The device's reply to the previous request. It is
                not used on the first call.

        Returns:
            The next serialized request (a latin-1 decoded ``str``), or
            ``None`` once the handshake has finished or the state is unexpected.
        """
        # protocomm security1 FSM which interprets/forms
        # protobuf packets according to present state of session
        if self.session_state == security_state.REQUEST1:
            self.session_state = security_state.RESPONSE1_REQUEST2
            return self.setup0_request()
        elif self.session_state == security_state.RESPONSE1_REQUEST2:
            self.session_state = security_state.RESPONSE2
            self.setup0_response(response_data)
            return self.setup1_request()
        elif self.session_state == security_state.RESPONSE2:
            self.session_state = security_state.FINISHED
            self.setup1_response(response_data)
            return None

        print('Unexpected state')
        return None

    def __generate_key(self):
        """Generate the client X25519 key pair.

        Stores the private key object and the raw-encoded public key bytes.
        """
        self.client_private_key = X25519PrivateKey.generate()
        self.client_public_key = self.client_private_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
        )

    def _print_verbose(self, data):
        """Print ``data`` in green, framed with ``++++``, when verbose is on."""
        if self.verbose:
            print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')

    def setup0_request(self):
        """Build SessionCmd0 carrying a freshly generated client public key.

        Returns:
            The serialized ``SessionData`` decoded as a latin-1 ``str``.
        """
        setup_req = proto.session_pb2.SessionData()
        setup_req.sec_ver = proto.session_pb2.SecScheme1
        self.__generate_key()
        setup_req.sec1.sc0.client_pubkey = self.client_public_key
        self._print_verbose(f'Client Public Key:\t0x{self.client_public_key.hex()}')
        return setup_req.SerializeToString().decode('latin-1')

    def setup0_response(self, response_data):
        """Parse SessionResp0 and derive the session cipher.

        Args:
            response_data: The serialized ``SessionData`` from the device.

        Raises:
            RuntimeError: If the response is not SecScheme1.
        """
        setup_resp = proto.session_pb2.SessionData()
        setup_resp.ParseFromString(str_to_bytes(response_data))
        self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
        if setup_resp.sec_ver != proto.session_pb2.SecScheme1:
            raise RuntimeError('Incorrect security scheme')

        self.device_public_key = setup_resp.sec1.sr0.device_pubkey
        # Device random is the initialization vector (CTR counter block)
        device_random = setup_resp.sec1.sr0.device_random
        self._print_verbose(f'Device Public Key:\t0x{self.device_public_key.hex()}')
        self._print_verbose(f'Device Random:\t0x{device_random.hex()}')

        # Calculate Curve25519 shared key using Client private key and Device public key
        sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
        self._print_verbose(f'Shared Key:\t0x{sharedK.hex()}')

        # If PoP is provided, XOR SHA256 of PoP with the previously
        # calculated Shared Key to form the actual Shared Key
        if len(self.pop) > 0:
            h = hashes.Hash(hashes.SHA256(), backend=default_backend())
            h.update(self.pop)
            digest = h.finalize()
            sharedK = a_xor_b(sharedK, digest)
            self._print_verbose(f'Updated Shared Key (Shared key XORed with PoP):\t0x{sharedK.hex()}')
        # Initialize the encryption engine with Shared Key and initialization vector
        cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
        self.cipher = cipher.encryptor()

    def setup1_request(self):
        """Build SessionCmd1 carrying the client proof.

        The proof is the device public key run through the session cipher.

        Returns:
            The serialized ``SessionData`` decoded as a latin-1 ``str``.
        """
        setup_req = proto.session_pb2.SessionData()
        setup_req.sec_ver = proto.session_pb2.SecScheme1
        setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
        # Encrypt device public key and attach to the request packet
        client_verify = self.cipher.update(self.device_public_key)
        self._print_verbose(f'Client Proof:\t0x{client_verify.hex()}')
        setup_req.sec1.sc1.client_verify_data = client_verify
        return setup_req.SerializeToString().decode('latin-1')

    def setup1_response(self, response_data):
        """Parse SessionResp1 and verify the device proof.

        The device proof, once passed through the session cipher, must equal
        the client public key. That shows the device derived the same shared
        key, which includes SHA256(PoP) when a PoP is set.

        Args:
            response_data: The serialized ``SessionData`` from the device.

        Raises:
            RuntimeError: If the response is not SecScheme1 or the proof
                does not match.
        """
        setup_resp = proto.session_pb2.SessionData()
        setup_resp.ParseFromString(str_to_bytes(response_data))
        # Ensure security scheme matches
        if setup_resp.sec_ver != proto.session_pb2.SecScheme1:
            raise RuntimeError('Unsupported security protocol')

        # Read encrypted device verify string
        device_verify = setup_resp.sec1.sr1.device_verify_data
        self._print_verbose(f'Device Proof:\t0x{device_verify.hex()}')
        # Decrypt the device verify string (in CTR mode, decryption is the same
        # keystream XOR as encryption, so the encryptor is reused)
        decrypted_client_pubkey = self.cipher.update(device_verify)
        # Match decrypted string with client public key. device_verify comes
        # from the peer, so compare in constant time.
        if not hmac.compare_digest(decrypted_client_pubkey, self.client_public_key):
            raise RuntimeError('Failed to verify device!')

    def encrypt_data(self, data):
        """Encrypt ``data`` with the session cipher.

        The CTR keystream position advances with every call.
        """
        return self.cipher.update(data)

    def decrypt_data(self, data):
        """Decrypt ``data`` with the session cipher.

        CTR decryption is the same keystream XOR as encryption, and this call
        shares (and advances) the stream used by :meth:`encrypt_data`.
        """
        return self.cipher.update(data)
